from typing import Optional

import numpy as np
from numpy.linalg import eigh
from scipy.linalg import eigh
from scipy.sparse import issparse
from scipy.sparse.csgraph import laplacian
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

from init_SVT import init_SVT
from losses import compute_acc_ari_nmi
# from sklearn_extra.cluster import KMedoids
import numpy as np
from typing import Optional, List, Union

def to_one_hot(labels: np.ndarray, k: int) -> np.ndarray:
    """Encode integer labels as a one-hot integer matrix.

    Args:
        labels: array-like of integer labels; it is flattened to 1-D first.
        k: number of columns (classes) in the encoding.

    Returns:
        An ``(N, k)`` integer array whose row ``i`` has a single 1 in column
        ``labels[i]``. Labels ``>= k`` raise ``IndexError``; negative labels
        follow NumPy's negative-index convention (e.g. -1 maps to column k-1).
    """
    flat = np.asarray(labels, dtype=int).ravel()
    # Row-selecting an identity matrix yields exactly one 1 per row.
    identity = np.eye(k, dtype=int)
    return identity[flat]

def _argmax_with_random_tie_breaking(scores: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    max_vals = scores.max(axis=1, keepdims=True)
    ties = (scores == max_vals)
    idx = np.empty(scores.shape[0], dtype=int)
    for i, row in enumerate(ties):
        candidates = np.where(row)[0]
        idx[i] = rng.choice(candidates)
    return idx

def local_refinement_by_neighbors(
    A: np.ndarray,
    pred_labels: np.ndarray,
    num_classes: int,
    random_state: Optional[Union[int, np.random.Generator]] = None,
    alpha: float = 1e-6,
    num_iters: int = 1,
    early_stop: bool = True,
    return_history: bool = False,
    verbose: bool = False,
) -> Union[np.ndarray, tuple]:
    """
    Multi-round local label refinement from neighborhood counts (synchronous update).

    In every round, each node is reassigned to the community maximizing

        score[i, c] = (sum_{j in c} A[i, j] + alpha) / (|c \\ {i}| + 2 * alpha)

    i.e. the (smoothed) weight from node i into community c divided by the size
    of c excluding i. Ties are broken at random. A node that is the only member
    of its current community always keeps its label.

    - Each round uses the previous round's labels as input.
    - If labels stop changing and ``early_stop=True``, iteration stops early.
    - If ``return_history=True``, returns ``(final_labels, history)``.

    Parameters
    ----------
    A : (N, N) adjacency matrix, dense or scipy sparse; a (1, N, N) array is also
        accepted. Weights are allowed. Self-loops are removed (on a copy, so the
        caller's array is not modified). Symmetry is expected but not checked.
    pred_labels : (N,) initial labels in 0..num_classes-1 (flattened, cast to int).
    num_classes : number of communities.
    random_state : int, None, or np.random.Generator
        Used for random tie-breaking. Passing an int makes the call reproducible.
        A Generator is used (and advanced) in place.
    alpha : smoothing constant that avoids division by zero for alpha > 0.
    num_iters : maximum number of rounds (values < 1 are treated as 1; 1 gives
        the single-pass behavior).
    early_stop : stop as soon as a round changes no labels.
    return_history : also return the list of labels produced by each round.
    verbose : print the number of changed nodes in each round.

    Returns
    -------
    final_labels, or (final_labels, history) when ``return_history`` is True.
    For an empty graph (N == 0) no round is run: an empty label array (and an
    empty history) is returned.

    Raises
    ------
    ValueError
        If A is not square, or pred_labels has the wrong length or contains
        values outside 0..num_classes-1.
    """
    # ---- Random number generator ----
    if isinstance(random_state, np.random.Generator):
        rng = random_state
    else:
        rng = np.random.default_rng(random_state)

    # ---- Validation and preprocessing ----
    # np.asarray cannot convert scipy sparse matrices to float arrays, so
    # densify them explicitly (toarray() returns a fresh array).
    if issparse(A):
        A = A.toarray()
    A = np.asarray(A, dtype=float)
    if A.ndim == 3 and A.shape[0] == 1:
        A = A[0]
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"A must be square (N,N), got {A.shape}")
    N = A.shape[0]

    # Remove self-loops on a copy so the caller's matrix is left untouched.
    if np.any(np.diag(A) != 0):
        A = A.copy()
        np.fill_diagonal(A, 0.0)

    y = np.asarray(pred_labels).reshape(-1).astype(int)
    if y.shape[0] != N:
        raise ValueError("pred_labels shape/value error")
    if N == 0:
        # Nothing to refine; y.min()/y.max() would fail on an empty array.
        return (y, []) if return_history else y
    if y.min() < 0 or y.max() >= num_classes:
        raise ValueError("pred_labels shape/value error")

    history: List[np.ndarray] = []
    rows = np.arange(N)

    # ---- Iterations ----
    for t in range(int(max(1, num_iters))):
        H = to_one_hot(y, num_classes)                 # (N, k) membership matrix
        deg_to_comm = A @ H                            # (N, k) weight from each node into each community
        comm_sizes = H.sum(axis=0, keepdims=True)      # (1, k) community sizes
        # (N, k): community sizes with node i removed from its own community.
        comm_sizes_excl_self = comm_sizes - H

        # Smoothed score (not a normalized probability). If alpha == 0 and a
        # community is empty after excluding i, this divides by zero.
        probs = (deg_to_comm + alpha) / (comm_sizes_excl_self + 2.0 * alpha)

        # Best-scoring community per node, ties broken at random.
        refined = _argmax_with_random_tie_breaking(probs, rng)

        # A node that is alone in its community keeps its label, so it does
        # not empty that community.
        stay_mask = comm_sizes_excl_self[rows, y] == 0
        refined[stay_mask] = y[stay_mask]

        if return_history:
            history.append(refined.copy())

        changes = int(np.count_nonzero(refined != y))
        if verbose:
            print(f"[local_refine iter {t+1}] changes={changes}")

        y = refined
        if early_stop and changes == 0:
            break

    if return_history:
        return y, history
    return y


import numpy as np

def _kmedians(U, k, n_init=20, max_iter=300, tol=1e-6, random_state=42):
    """
    纯 NumPy 的 K-medians（L1 距离 + 坐标-wise 中位数作为中心）
    - 返回: labels (n,), centers (k, d)
    """
    rng = np.random.default_rng(random_state)
    n, d = U.shape
    best_labels, best_centers, best_inertia = None, None, np.inf

    def l1_inertia(X, centers, labels):
        # sum over clusters of L1 distances
        inertia = 0.0
        for j in range(k):
            mask = labels == j
            if not np.any(mask):
                continue
            inertia += np.abs(X[mask] - centers[j]).sum()
        return inertia

    for _ in range(n_init):
        # 初始化：随机选 k 个样本做中心（你也可以实现 L1 版的 k-means++）
        init_idx = rng.choice(n, size=k, replace=False)
        centers = U[init_idx].copy()

        for _ in range(max_iter):
            # 分配：L1 距离
            # 距离矩阵 (n, k)
            dists = np.abs(U[:, None, :] - centers[None, :, :]).sum(axis=2)
            labels = np.argmin(dists, axis=1)

            # 更新：坐标-wise median
            new_centers = centers.copy()
            for j in range(k):
                mask = labels == j
                if not np.any(mask):
                    # 空簇：重置到随机样本
                    new_centers[j] = U[rng.integers(0, n)]
                else:
                    # 每个维度取中位数（更鲁棒）
                    new_centers[j] = np.median(U[mask], axis=0)

            # 收敛判据（中心变化量）
            shift = np.abs(new_centers - centers).sum()
            centers = new_centers
            if shift <= tol:
                break

        inertia = l1_inertia(U, centers, labels)
        if inertia < best_inertia:
            best_inertia = inertia
            best_labels, best_centers = labels.copy(), centers.copy()

    return best_labels, best_centers


def _cluster_from_U(U, k):
    """
    Cluster the rows of an embedding matrix with K-medians (L1 distance).

    This replaces the original KMeans step. It uses 20 restarts, at most 300
    rounds per restart, tol=1e-6 and a fixed seed of 42.

    Returns:
        labels of shape (n,), one cluster index per row of ``U``.
    """
    points = np.asarray(U, dtype=float)
    assignments, _centers = _kmedians(
        points, k=k, n_init=20, max_iter=300, tol=1e-6, random_state=42
    )
    return assignments


def spectral_clustering_adj(A, k, true_labels, normalized: bool = False, *, run_all: bool = False,
                            random_state: int = 0):
    """
    Spectral clustering of an adjacency matrix, followed by local refinement,
    with evaluation against ground-truth labels.

    Four embeddings are supported:
    - normalized Laplacian, with Ng-Jordan-Weiss row normalization;
    - unnormalized Laplacian;
    - adjacency spectrum (top-k eigenvalues by magnitude);
    - SVT initialization (``init_SVT``).

    Each embedding goes through the same pipeline:
    1. Cluster it with ``_cluster_from_U`` (K-medians).
    2. Score the result with ``compute_acc_ari_nmi``.
    3. Refine the matched prediction with ``local_refinement_by_neighbors``.
    4. Score the refined labels again.

    Legacy interface (compatible with test_single_first_period):
        spectral_clustering_adj(A, k, true_labels, normalized=False)
        -> (acc_sc, ari_sc, nmi_sc, acc_ref, ari_ref, nmi_ref)
        Only the normalized (normalized=True) or unnormalized Laplacian path runs.

    All methods at once:
        spectral_clustering_adj(A, k, true_labels, run_all=True)
        -> {
            'normalized':   {'sc': (acc,ari,nmi), 'refined': (acc,ari,nmi)},
            'unnormalized': {'sc': (acc,ari,nmi), 'refined': (acc,ari,nmi)},
            'adjacency':    {'sc': (acc,ari,nmi), 'refined': (acc,ari,nmi)},
            'svt':          {'sc': (acc,ari,nmi), 'refined': (acc,ari,nmi)},
        }
        ``normalized`` is ignored in this mode.

    Parameters
    ----------
    A : adjacency matrix. Accepted forms: NumPy array, scipy sparse matrix,
        torch.Tensor, or array-like, of shape (N, N) or (1, N, N). It is
        densified, cast to float64 and symmetrized as (A + A.T) / 2.
    k : number of clusters (and of eigenvectors kept).
    true_labels : ground-truth labels. A torch.Tensor is converted to NumPy.
    normalized : legacy mode only; selects the normalized Laplacian path.
    run_all : run every method and return the nested dict shown above.
    random_state : seed forwarded to ``local_refinement_by_neighbors`` so that
        its random tie-breaking is reproducible. The clustering step's seed is
        set inside ``_cluster_from_U``, not here.

    Raises
    ------
    ValueError
        If A is not square after normalization.

    Depends on: laplacian, eigh, _cluster_from_U, compute_acc_ari_nmi,
    local_refinement_by_neighbors, init_SVT.
    """

    # --- Input normalization ---
    # torch is optional. Only the import is best-effort, so genuine conversion
    # errors are no longer silently swallowed.
    try:
        import torch
    except (ImportError, OSError):
        torch = None
    if torch is not None:
        if isinstance(A, torch.Tensor):
            A = A.detach().cpu().numpy()
        if isinstance(true_labels, torch.Tensor):
            true_labels = true_labels.detach().cpu().numpy()

    # Densify and coerce to ndarray before inspecting .ndim/.shape, so that
    # array-likes such as nested lists are accepted too.
    if issparse(A):
        A = A.toarray()
    A = np.asarray(A, dtype=np.float64)
    if A.ndim == 3 and A.shape[0] == 1:
        A = A[0]
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"A must be square (N,N) or (1,N,N), got {A.shape}")

    # Ensure symmetry.
    A = 0.5 * (A + A.T)

    def _cluster_and_refine(U):
        """Cluster embedding U, score it, refine it, score again."""
        labels = _cluster_from_U(U, k)
        acc_sc, best_pred, ari_sc, nmi_sc = compute_acc_ari_nmi(labels, true_labels, k)
        refined = local_refinement_by_neighbors(A, best_pred, k, random_state=random_state)
        acc_ref, _, ari_ref, nmi_ref = compute_acc_ari_nmi(refined, true_labels, k)
        return (acc_sc, ari_sc, nmi_sc), (acc_ref, ari_ref, nmi_ref)

    def _run_normalized():
        L = laplacian(A, normed=True).astype(np.float64)
        w, V = eigh(L)
        U = V[:, np.argsort(w)[:k]]  # eigenvectors of the k smallest eigenvalues
        # Ng-Jordan-Weiss row normalization; epsilon guards all-zero rows.
        U = U / (np.linalg.norm(U, axis=1, keepdims=True) + 1e-12)
        return _cluster_and_refine(U)

    def _run_unnormalized():
        L = laplacian(A, normed=False).astype(np.float64)
        w, V = eigh(L)
        U = V[:, np.argsort(w)[:k]]  # eigenvectors of the k smallest eigenvalues
        return _cluster_and_refine(U)

    def _run_adjacency():
        # Full dense eigendecomposition of A.
        w, V = eigh(A)
        # Keep the k eigenvectors with the largest |eigenvalue|, ordered by
        # decreasing |eigenvalue|.
        idx = np.argsort(np.abs(w))[-k:]
        idx = idx[np.argsort(np.abs(w[idx]))[::-1]]
        return _cluster_and_refine(V[:, idx])

    def _run_svt():
        # Latent vectors from the SVT initialization. The second return value
        # is unused here.
        Z0, _alpha0 = init_SVT(A, k)
        return _cluster_and_refine(Z0)

    if run_all:
        sc_n, ref_n = _run_normalized()
        sc_u, ref_u = _run_unnormalized()
        sc_a, ref_a = _run_adjacency()
        sc_svt, ref_svt = _run_svt()
        return {
            'normalized': {'sc': sc_n, 'refined': ref_n},
            'unnormalized': {'sc': sc_u, 'refined': ref_u},
            'adjacency': {'sc': sc_a, 'refined': ref_a},
            'svt': {'sc': sc_svt, 'refined': ref_svt},
        }

    # Legacy interface: a single Laplacian path chosen by `normalized`,
    # returned as a flat 6-tuple.
    sc, ref = _run_normalized() if normalized else _run_unnormalized()
    acc_sc, ari_sc, nmi_sc = sc
    acc_ref, ari_ref, nmi_ref = ref
    return acc_sc, ari_sc, nmi_sc, acc_ref, ari_ref, nmi_ref
